import numpy as np
import torch
import matplotlib.pyplot as plt
from distributed_pcg.newton_batch import newton, dist_newton_avg, dist_newton_shrinkage, dist_newton_det,  dist_newton_exact, dist_newton_dane
import seaborn as sns

    
def _dataset_spec(dim=None, *, dim_from_line=None, start=1, strict=True,
                  labels='float'):
    """Bundle the parsing parameters of one dataset into a dict.

    dim           -- fixed number of feature columns, or None to infer it.
    dim_from_line -- index of the line whose token count (minus the label)
                     gives the number of columns when ``dim`` is None.
    start         -- index of the first ``index:value`` token on a line.
    strict        -- True: a malformed row aborts loading; False: the row is
                     reported and left as zeros.
    labels        -- key into ``_LABEL_TRANSFORMS``.
    """
    return {'dim': dim, 'dim_from_line': dim_from_line, 'start': start,
            'strict': strict, 'labels': labels}


# Post-processing applied to the raw (float) label column.
_LABEL_TRANSFORMS = {
    # Labels are kept as floats.
    'float': lambda y: y.astype(float),
    # Labels are truncated to integers.
    'int': lambda y: y.astype(int),
    # (y + 1) / 2 + 1: sends -1 to 1 and +1 to 2.
    'half_plus_one': lambda y: ((y.astype(int) + 1) / 2 + 1).astype(int),
    # y - 3: sends 2 to -1 and 4 to +1.
    'minus_three': lambda y: y.astype(float) - 3,
    # 2 * y - 1: sends 0 to -1 and 1 to +1.
    'double_minus_one': lambda y: (y * 2 - 1).astype(float),
}


# One entry per supported dataset; the data is read from '<name>.txt'.
# NOTE(review): start=2 presumably skips the empty token produced by a double
# space after the label in those files -- confirm against the data files.
_DATASET_SPECS = {
    'cod-rna': _dataset_spec(dim_from_line=1, labels='half_plus_one'),
    'svmguide3': _dataset_spec(dim_from_line=0, labels='half_plus_one'),
    'abalone': _dataset_spec(dim_from_line=0, labels='int'),
    'vehicle': _dataset_spec(dim_from_line=0, labels='int'),
    'bodyfat': _dataset_spec(14),
    'iris': _dataset_spec(4),
    'mpg': _dataset_spec(7),
    'mg': _dataset_spec(6),
    'pyrim': _dataset_spec(27),
    'triazines': _dataset_spec(60),
    'eunite2001': _dataset_spec(16),
    'housing': _dataset_spec(13),
    'cpusmall': _dataset_spec(12),
    'space_ga': _dataset_spec(6),
    'a1a': _dataset_spec(123),
    'australian': _dataset_spec(14),
    'diabetes': _dataset_spec(8, start=2),
    'mushrooms': _dataset_spec(112),
    'splice': _dataset_spec(60),
    'segment': _dataset_spec(19),
    'pendigits': _dataset_spec(16),
    'breast_cancer': _dataset_spec(10, start=2, labels='minus_three'),
    'fourclass': _dataset_spec(2),
    'heart': _dataset_spec(13, strict=False),
    'ionosphere': _dataset_spec(34, strict=False),
    'liver-disorders': _dataset_spec(5, start=2, strict=False,
                                     labels='double_minus_one'),
    'sonar': _dataset_spec(60, start=2, strict=False),
    'svmguide1': _dataset_spec(4, start=2, strict=False,
                               labels='double_minus_one'),
    'german_numer': _dataset_spec(24, start=2, strict=False),
}


def _fill_row(row, tokens):
    """Write the ``index:value`` tokens of one line into ``row`` in place.

    Indices in the file are 1-based and are shifted to 0-based columns.
    Raises ValueError or IndexError when a token is malformed or its index
    falls outside the row.
    """
    pairs = [tok.split(':') for tok in tokens]
    # Values are converted before indices so that a malformed token fails
    # with the same exception type as before (a missing ':' -> IndexError).
    values = np.array([p[1] for p in pairs]).astype(float)
    columns = np.array([p[0] for p in pairs]).astype(int) - 1
    row[columns] = values


def read_dataset(dataset):
    """Load a dataset stored as sparse ``label index:value ...`` text lines.

    The file ``<dataset>.txt`` is read from the current working directory.
    Each non-blank line is one sample: the first space-separated token is the
    label, the remaining tokens are 1-based ``index:value`` pairs. Features
    that are not listed stay at zero. Blank lines are ignored.

    Parameters
    ----------
    dataset : str
        Name of the dataset; must be a key of ``_DATASET_SPECS``.

    Returns
    -------
    (M, d, X, Y)
        M -- number of samples, d -- number of feature columns,
        X -- dense ``(M, d)`` float array, Y -- label array of length M
        (integer or float depending on the dataset's label transform).

    Raises
    ------
    Exception
        If ``dataset`` is not a known name ('invalid dataset').
    OSError
        If the data file cannot be opened.
    ValueError, IndexError
        If a label cannot be parsed, if the file is too short to infer the
        dimension, or (for datasets loaded strictly) if a feature token is
        malformed or out of range. Datasets loaded leniently only report
        such a row and leave it filled with zeros.
    """
    def printdataset(*, name, X):
        print('data set %s loaded, M=%d, d=%d'%(name, X.shape[0], X.shape[1]))

    spec = _DATASET_SPECS.get(dataset) if isinstance(dataset, str) else None
    if spec is None:
        raise Exception('invalid dataset')

    path = dataset + '.txt'
    with open(path) as f:
        # A blank (e.g. trailing) line carries no sample; keeping it would
        # only make label parsing fail on an empty string.
        lines = [line for line in f if line.strip()]

    M = len(lines)
    if spec['dim'] is not None:
        d = spec['dim']
    else:
        ref = spec['dim_from_line']
        if ref >= M:
            raise IndexError(
                '%s: need at least %d data line(s) to infer the dimension'
                % (path, ref + 1))
        # Every token after the label counts as one feature column.
        d = len(lines[ref].strip().split(' ')) - 1

    X = np.zeros((M, d))
    start = spec['start']
    for i, line in enumerate(lines):
        tokens = line.strip().split(' ')[start:]
        try:
            _fill_row(X[i], tokens)
        except (ValueError, IndexError):
            print('%s: could not parse row %d: %r' % (path, i, line.strip()))
            if spec['strict']:
                raise
            # Lenient datasets: keep going, the row stays all zeros.

    raw_labels = np.array([float(line.strip().split(' ')[0]) for line in lines])
    Y = _LABEL_TRANSFORMS[spec['labels']](raw_labels)

    printdataset(name=dataset, X=X)
    return M,d,X,Y

